import torch
import torch.nn as nn
import torch_geometric.nn as gnn
import torch_geometric
import pickle
import torch
from torch.nn import Parameter
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.data import Data, DataLoader

from sklearn.metrics import roc_auc_score, average_precision_score
from torch_geometric.utils import (negative_sampling, remove_self_loops,
								   add_self_loops)
import numpy as np
from torch_geometric.utils import to_dense_batch
from torch import optim
import torch_geometric.nn as geom_nn
import random 
import copy
from predictrel import RelPredictor


# Number of leading columns of each decoded node vector that VAE.forward
# slices out before relation prediction.
# NOTE(review): "magenta" presumably names the source of the node embeddings — confirm.
magenta_size = 256
# Node feature width used to size every NNConv layer and Linear head below;
# tied to magenta_size.
num_features = magenta_size
# NOTE(review): not referenced anywhere in the visible part of this file.
num_symmetries = 23
# Length of each edge attribute vector; input width of NNConv's edge network.
num_edge_features = 49#
# Message-passing layer: a customised variant of torch_geometric's NNConv (edge-conditioned convolution).
class NNConv(MessagePassing):
	r"""Edge-conditioned message passing layer with a learned node update.

	A modified variant of the edge-conditioned convolution described in
	`"Neural Message Passing for Quantum Chemistry"
	<https://arxiv.org/abs/1704.01212>`_ and
	`"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on
	Graphs" <https://arxiv.org/abs/1704.02901>`_.

	Unlike the upstream ``torch_geometric`` layer of the same name, this
	class builds its own edge network and works as follows:

	1. A single linear layer maps each edge attribute vector to a
	   ``[num_features, 20]`` weight matrix.
	2. The message for an edge is the sigmoid of the gathered neighbour
	   features ``x_j`` multiplied by that matrix (20 values per edge).
	3. Messages are aggregated by :class:`MessagePassing` using ``aggr``.
	4. The aggregate is concatenated with the node's own features and passed
	   through a two-layer MLP that returns ``num_features`` values per node.

	Args:
		num_features (int): Size of each node feature vector, both on input
			and on output.
		num_edge_features (int): Size of each edge attribute vector.
		in_channels (int): Only used for the shape of ``root`` and in
			``__repr__``; it does not affect the forward computation.
		out_channels (int): Accepted for interface compatibility but
			ignored: the message size is fixed at 20.
		aggr (string, optional): The aggregation scheme forwarded to
			:class:`MessagePassing`. (default: :obj:`"add"`)
		root_weight (bool, optional): If :obj:`True`, register a ``root``
			parameter of shape ``[in_channels, 20]``. The parameter is not
			used by ``forward``. (default: :obj:`True`)
		bias (bool, optional): If :obj:`True`, register a ``bias`` parameter
			of shape ``[20]``. The parameter is not used by ``forward``.
			(default: :obj:`True`)
		**kwargs (optional): Additional arguments of
			:class:`torch_geometric.nn.conv.MessagePassing`.
	"""

	def __init__(self, num_features, num_edge_features, in_channels, out_channels, aggr='add', root_weight=True, bias=True, **kwargs):
		super().__init__(aggr=aggr, **kwargs)
		self.n_features = num_features
		self.in_channels = in_channels
		# NOTE: the caller-supplied ``out_channels`` is overridden here. Every
		# parameter shape below depends on this value, so honouring the
		# argument would change the shapes stored in existing state dicts.
		out_channels = 20
		self.out_channels = out_channels

		# Edge network: edge attributes -> flattened [n_features, 20] matrix.
		self.nn = nn.Sequential(nn.Linear(num_edge_features, self.n_features * out_channels))
		# Node update network: [aggregated message | own features] -> features.
		self.aggrnn = nn.Sequential(
			nn.Linear(num_features + out_channels, num_features),
			nn.ReLU(),
			nn.Linear(num_features, num_features),
		)

		# ``root`` and ``bias`` are registered (so they appear in the state
		# dict and are seen by optimisers) but no method of this class reads
		# them.
		if root_weight:
			self.root = Parameter(torch.randn(in_channels, out_channels))
		else:
			self.register_parameter('root', None)

		if bias:
			self.bias = nn.Parameter(torch.randn(out_channels))
		else:
			self.register_parameter('bias', None)

	def forward(self, x, edge_index, edge_attr, prin = False):
		"""Run one round of message passing.

		1-D ``x`` or ``edge_attr`` inputs are given a trailing dimension of
		size 1 before propagation. When ``prin`` is true, the first ten
		output values of node 0 are printed as a debugging aid.

		Returns the result of ``propagate``, i.e. the output of ``update``.
		"""
		x = x.unsqueeze(-1) if x.dim() == 1 else x
		pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
		a = self.propagate(edge_index, x=x, pseudo=pseudo)
		if prin:
			print(a[0,:10])
		return a

	def message(self, x_j, pseudo):
		"""Compute the per-edge message from neighbour features and edge attributes.

		The parameter names are significant: ``propagate`` matches them
		against the keyword arguments passed in ``forward``.
		"""
		# One [n_features, out_channels] matrix per edge.
		weight = self.nn(pseudo).view(-1, self.n_features, self.out_channels)
		# [E, 1, n_features] @ [E, n_features, out_channels] -> [E, out_channels]
		projected = torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)
		return torch.sigmoid(projected)

	def update(self, aggr_out, x):
		"""Combine the aggregated messages with each node's own features."""
		y = torch.cat((aggr_out, x), dim=1)
		return self.aggrnn(y)

	def __repr__(self):
		return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
								   self.out_channels)
"""
def cov(m, rowvar=True, inplace=False):
	'''Estimate a covariance matrix given data.

	Covariance indicates the level to which two variables vary together.
	If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
	then the covariance matrix element `C_{ij}` is the covariance of
	`x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.

	Args:
	    m: A 1-D or 2-D array containing multiple variables and observations.
	        Each row of `m` represents a variable, and each column a single
	        observation of all those variables.
	    rowvar: If `rowvar` is True, then each row represents a
	        variable, with observations in the columns. Otherwise, the
	        relationship is transposed: each column represents a variable,
	        while the rows contain observations.

	Returns:
	    The covariance matrix of the variables.
	'''
	# m = m.type(torch.double)  # uncomment this line if desired
	fact = 1.0 / (m.size(1) - 1)
	mt = (m - torch.mean(m, dim=1, keepdim=True)).t()  # if complex: mt = m.t().conj()
	return fact * m.matmul(mt).squeeze()
"""
def cov(m, y=None):
    """Return the sample covariance matrix of the rows of ``m``.

    Each row is treated as one variable and each column as one observation.
    When ``y`` is given, its rows are appended below those of ``m`` (along
    dim 0) before the covariance is computed.

    The result is a square matrix with one row/column per variable,
    normalised by ``(number of columns - 1)``.

    Raises:
        ZeroDivisionError: if the input has exactly one column.
    """
    samples = m if y is None else torch.cat((m, y), dim=0)
    # Centre every row on its own mean.
    centered = samples - torch.mean(samples, dim=1, keepdim=True)
    n_obs = centered.size(1)
    return 1 / (n_obs - 1) * centered.mm(centered.t())

class Encoder(nn.Module):
	"""Graph encoder that maps node features to two 200-wide vectors per node.

	The input passes through two NNConv layers and then through two
	independent linear heads, ``mu`` and ``logvar``.
	"""

	def __init__(self):
		super().__init__()
		latent_dim = 200

		def build_layer():
			return NNConv(num_features, num_edge_features, num_edge_features, num_features)

		# Four layers are created, in this order, although forward() only
		# applies the first two; conv3 and conv4 are still registered
		# submodules and therefore part of the state dict.
		self.conv1 = build_layer()
		self.conv2 = build_layer()
		self.conv3 = build_layer()
		self.conv4 = build_layer()
		self.mu = nn.Linear(num_features, latent_dim)
		self.logvar = nn.Linear(num_features, latent_dim)

	def forward(self, x, edge_ind, edge_attr):
		"""Return ``(x_mu, x_log)``, the outputs of the ``mu`` and ``logvar`` heads."""
		hidden = x
		for layer in (self.conv1, self.conv2):
			hidden = layer(hidden, edge_ind, edge_attr)
		return self.mu(hidden), self.logvar(hidden)

class Decoder(nn.Module):
	"""Graph decoder that maps 200-wide node vectors back to ``num_features`` values.

	The input is first expanded by a two-layer MLP (``autodec``) and then
	passed through two NNConv layers.
	"""

	def __init__(self):
		super().__init__()
		latent_dim = 200

		def build_layer():
			return NNConv(num_features, num_edge_features, num_edge_features, num_features)

		# Five layers are created, in this order, although forward() only
		# applies the first two; conv3..conv5 are still registered
		# submodules and therefore part of the state dict.
		self.conv1 = build_layer()
		self.conv2 = build_layer()
		self.conv3 = build_layer()
		self.conv4 = build_layer()
		self.conv5 = build_layer()
		self.autodec = nn.Sequential(
			nn.Linear(latent_dim, num_features),
			nn.ReLU(),
			nn.Linear(num_features, num_features),
		)

	def forward(self, x, edge_ind, edge_attr):
		"""Return the decoded node features for the given graph."""
		hidden = self.autodec(x)
		for layer in (self.conv1, self.conv2):
			hidden = layer(hidden, edge_ind, edge_attr)
		return hidden

class VAE(nn.Module):
	"""Graph variational autoencoder with a pretrained edge-relation predictor.

	Combines :class:`Encoder`, :class:`Decoder` and a ``RelPredictor`` whose
	weights are loaded from a checkpoint when the model is constructed.

	Args:
		rel_predictor_path: Path of the ``RelPredictor`` state dict to load.
			Defaults to ``"graphnn/relpredictor.pth"``, resolved relative to
			the current working directory.
	"""

	def __init__(self, rel_predictor_path="graphnn/relpredictor.pth"):
		super().__init__()
		self.encoder = Encoder()
		self.decoder = Decoder()
		self.rel_predict = RelPredictor()
		# Load onto the CPU first so a checkpoint saved from a CUDA device can
		# still be read on a machine without one; load_state_dict then copies
		# the values into the module's existing parameters.
		state = torch.load(rel_predictor_path, map_location="cpu")
		self.rel_predict.load_state_dict(state)

	def reparameterize(self, mu, logvar):
		"""Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from N(0, 1)."""
		std = torch.exp(0.5*logvar)
		eps = torch.randn_like(std)
		return mu + eps*std

	def forward(self, x, edge_index, edge_attr):
		"""Encode, sample, decode, and predict edge attributes.

		Returns:
			A tuple ``(x, z, x_mu, x_log, attr_predict)``: the decoded node
			features, the sampled latent vectors, the two encoder outputs,
			and the output of ``rel_predict`` for every edge.
		"""
		x_mu, x_log = self.encoder(x, edge_index, edge_attr)
		z = self.reparameterize(x_mu, x_log)
		x = self.decoder(z, edge_index, edge_attr)
		# For each edge, join the first magenta_size decoded features of its
		# two endpoint nodes into one row.
		first_nodes = x[edge_index[0,:]][:,:magenta_size]
		second_nodes = x[edge_index[1,:]][:,:magenta_size]
		edge_x = torch.cat([first_nodes, second_nodes], dim=1)
		attr_predict = self.rel_predict(edge_x)

		return x, z, x_mu, x_log, attr_predict

